import os
import sys
import time
import glob
import math

import numpy as np
from model import *
from config import *
from keras.models import load_model, save_model
from keras.layers import Input
from keras import optimizers
from keras import backend as K
import dataset
import config
import misc


def load_GD(path, compile=False):
    G_path = os.path.join(path, 'Generator.h5')
    D_path = os.path.join(path, 'Discriminator.h5')
    G = load_model(G_path, compile=compile)
    D = load_model(D_path, compile=compile)
    return G, D


def save_GD(G, D, path, overwrite=False):
    os.makedirs(path, exist_ok=True)
    G_path = os.path.join(path, 'Generator.h5')
    D_path = os.path.join(path, 'Discriminator.h5')
    save_model(G, G_path, overwrite=overwrite)
    save_model(D, D_path, overwrite=overwrite)
    print("Saved model to %s" % path)


def load_GD_weights(G, D, path, by_name=True):
    G_path = os.path.join(path, 'Generator.h5')
    D_path = os.path.join(path, 'Discriminator.h5')
    G.load_weights(G_path, by_name=by_name)
    D.load_weights(D_path, by_name=by_name)
    return G, D


def save_GD_weights(G, D, path):
    try:
        os.makedirs(path, exist_ok=True)
        G_path = os.path.join(path, 'Generator.h5')
        D_path = os.path.join(path, 'Discriminator.h5')
        G.save_weights(G_path)
        D.save_weights(D_path)
        print("Saved weights to %s" % path)
    except OSError as exc:
        print("Saving model snapshot failed: %s" % exc)


def rampup(epoch, rampup_length):
    # Smooth (Gaussian) ramp from exp(-5) up to 1.0 over the first
    # `rampup_length` epochs.
    if epoch < rampup_length:
        p = max(0.0, float(epoch)) / float(rampup_length)
        p = 1.0 - p
        return math.exp(-p * p * 5.0)
    else:
        return 1.0


def format_time(seconds):
    s = int(np.round(seconds))
    if s < 60:
        return '%ds' % s
    elif s < 60 * 60:
        return '%dm %02ds' % (s // 60, s % 60)
    elif s < 24 * 60 * 60:
        return '%dh %02dm %02ds' % (s // (60 * 60), (s // 60) % 60, s % 60)
    else:
        return '%dd %dh %02dm' % (s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)


def rampdown_linear(epoch, num_epochs, rampdown_length):
    # Linear ramp from 1.0 down to 0.0 over the last `rampdown_length` epochs.
    if epoch >= num_epochs - rampdown_length:
        return float(num_epochs - epoch) / rampdown_length
    else:
        return 1.0


def create_result_subdir(result_dir, run_desc):
    # Select run ID and create subdir.
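    # Existing run dirs are named '%03d-<desc>'; scan them for the highest ID,
    # take the next one, and retry on a collision with a concurrent run.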
    while True:
        run_id = 0
        for fname in glob.glob(os.path.join(result_dir, '*')):
            try:
                fbase = os.path.basename(fname)
                ford = int(fbase[:fbase.find('-')])
                run_id = max(run_id, ford + 1)
            except ValueError:
                pass
        result_subdir = os.path.join(result_dir, '%03d-%s' % (run_id, run_desc))
        try:
            os.makedirs(result_subdir)
            break
        except OSError:
            if os.path.isdir(result_subdir):
                continue  # lost the race to another process; pick a new ID
            raise
    print("Saving results to", result_subdir)
    return result_subdir


def random_latents(num_latents, G_input_shape):
    return np.random.randn(num_latents, *G_input_shape[1:]).astype(np.float32)


def random_labels(num_labels, training_set):
    return training_set.labels[np.random.randint(training_set.labels.shape[0], size=num_labels)]


def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)


def multiple_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)


def mean_loss(y_true, y_pred):
    return K.mean(y_pred)


def load_dataset(dataset_spec=None, verbose=True, **spec_overrides):
    if verbose: print('Loading dataset...')
    if dataset_spec is None: dataset_spec = config.dataset
    dataset_spec = dict(dataset_spec)  # take a copy of the dict before modifying it
    dataset_spec.update(spec_overrides)
    dataset_spec['h5_path'] = os.path.join(config.data_dir, dataset_spec['h5_path'])
    if 'label_path' in dataset_spec:
        dataset_spec['label_path'] = os.path.join(config.data_dir, dataset_spec['label_path'])
    training_set = dataset.Dataset(**dataset_spec)
    if verbose: print('Dataset shape =', np.int32(training_set.shape).tolist())
    drange_orig = training_set.get_dynamic_range()
    if verbose: print('Dynamic range =', drange_orig)
    return training_set, drange_orig


speed_factor = 20  # divides all kimg schedules below for faster experiments


def train_gan(
        separate_funcs          = False,
        D_training_repeats      = 1,
        G_learning_rate_max     = 0.0010,
        D_learning_rate_max     = 0.0010,
        G_smoothing             = 0.999,
        adam_beta1              = 0.0,
        adam_beta2              = 0.99,
        adam_epsilon            = 1e-8,
        minibatch_default       = 16,
        minibatch_overrides     = {},
        rampup_kimg             = 40 / speed_factor,
        rampdown_kimg           = 0,
        lod_initial_resolution  = 4,
        lod_training_kimg       = 400 / speed_factor,
        lod_transition_kimg     = 400 / speed_factor,
        total_kimg              = 10000 / speed_factor,
        dequantize_reals        = False,
        gdrop_beta              = 0.9,
        gdrop_lim               = 0.5,
        gdrop_coef              = 0.2,
        gdrop_exp               = 2.0,
        drange_net              = [-1, 1],
        drange_viz              = [-1, 1],
        image_grid_size         = None,
        tick_kimg_default       = 50 / speed_factor,
        tick_kimg_overrides     = {32: 20, 64: 10, 128: 10, 256: 5, 512: 2, 1024: 1},
        image_snapshot_ticks    = 1,
        network_snapshot_ticks  = 4,
        image_grid_type         = 'default',
        # resume_network        = '000-celeba/network-snapshot-000488',
        resume_network          = None,
        resume_kimg             = 0.0,
        resume_time             = 0.0):

    training_set, drange_orig = load_dataset()

    G = Generator(num_channels=training_set.shape[3], resolution=training_set.shape[1],
                  label_size=training_set.labels.shape[1], **config.G)
    D = Discriminator(num_channels=training_set.shape[3], resolution=training_set.shape[1],
                      label_size=training_set.labels.shape[1], **config.D)
    if resume_network:
        print("Resuming weights from: " + resume_network)
        G, D = load_GD_weights(G, D, os.path.join(config.result_dir, resume_network), True)
    G_train, D_train = PG_GAN(G, D, config.G['latent_size'], 0, training_set.shape[1], training_set.shape[3])
    G.summary()
    D.summary()

    # Misc init.
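    # LOD (level of detail) bookkeeping: lod 0 is the full resolution and each
    # +1 halves it; training starts at lod_initial_resolution and fades in
    # higher-resolution layers as cur_lod decreases toward 0.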
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    min_lod, max_lod = -1.0, -2.0  # sentinels: force the first LOD update
    fake_score_avg = 0.0

    G_opt = optimizers.Adam(lr=0.0, beta_1=adam_beta1, beta_2=adam_beta2, epsilon=adam_epsilon)
    D_opt = optimizers.Adam(lr=0.0, beta_1=adam_beta1, beta_2=adam_beta2, epsilon=adam_epsilon)

    if config.loss['type'] == 'wass':
        G_loss = wasserstein_loss
        D_loss = wasserstein_loss
        D_loss_weight = None  # single loss, no extra weighting needed
    elif config.loss['type'] == 'iwass':
        # Improved Wasserstein (WGAN-GP): mean critic output plus an MSE term
        # on the interpolated-sample gradient norm, weighted by iwass_lambda.
        G_loss = multiple_loss
        D_loss = [mean_loss, 'mse']
        D_loss_weight = [1.0, config.loss['iwass_lambda']]

    G.compile(G_opt, loss=G_loss)
    D.trainable = False
    G_train.compile(G_opt, loss=G_loss)
    D.trainable = True
    D_train.compile(D_opt, loss=D_loss, loss_weights=D_loss_weight)

    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time

    if image_grid_type == 'default':
        if image_grid_size is None:
            w, h = G.output_shape[1], G.output_shape[2]
            print("w:%d,h:%d" % (w, h))
            image_grid_size = (int(np.clip(1920 // w, 3, 16)), int(np.clip(1080 // h, 2, 16)))
            print("image_grid_size:", image_grid_size)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch_channel_last(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size), G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    result_subdir = misc.create_result_subdir(config.result_dir, config.run_desc)
    print("example_real_images.shape:", example_real_images.shape)
    misc.save_image_grid(example_real_images, os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig, grid_size=image_grid_size)
    snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir, 'fakes%06d.png' % (cur_nimg // 1000)),
                         drange=drange_viz, grid_size=image_grid_size)

    nimg_h = 0
    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg + lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(1.0 + (np.fmod(tlod, 1.0) - 1.0) * (lod_training_kimg + lod_transition_kimg) / lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2 ** (resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res, tick_kimg_default)

        # Update network config.
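        # Ramp the shared learning-rate coefficient up over the first
        # rampup_kimg and linearly down over the final rampdown_kimg, then
        # push the current LOD into the training graphs so the fade-in
        # blending (where the models expose cur_lod) tracks the schedule.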
        lrate_coef = rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= rampdown_linear(cur_nimg / 1000.0, total_kimg, rampdown_kimg)
        K.set_value(G.optimizer.lr, np.float32(lrate_coef * G_learning_rate_max))
        K.set_value(G_train.optimizer.lr, np.float32(lrate_coef * G_learning_rate_max))
        K.set_value(D_train.optimizer.lr, np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G_train, 'cur_lod'): K.set_value(G_train.cur_lod, np.float32(cur_lod))
        if hasattr(D_train, 'cur_lod'): K.set_value(D_train.cur_lod, np.float32(cur_lod))

        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            min_lod, max_lod = new_min_lod, new_max_lod

        # Train D.
        d_loss = None
        for idx in range(D_training_repeats):
            mb_reals, mb_labels = training_set.get_random_minibatch_channel_last(
                minibatch_size, lod=cur_lod, shrink_based_on_lod=True, labels=True)
            mb_latents = random_latents(minibatch_size, G.input_shape)
            mb_labels_rnd = random_labels(minibatch_size, training_set)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                mb_reals = np.repeat(mb_reals, 2 ** min_lod, axis=1)
                mb_reals = np.repeat(mb_reals, 2 ** min_lod, axis=2)
            mb_fakes = G.predict_on_batch([mb_latents])

            # Map reals into the network's dynamic range before mixing them
            # with fakes, so the gradient-penalty interpolates live in one range.
            mb_reals = misc.adjust_dynamic_range(mb_reals, drange_orig, drange_net)
            epsilon = np.random.uniform(0, 1, size=(minibatch_size, 1, 1, 1))
            interpolation = epsilon * mb_reals + (1 - epsilon) * mb_fakes
            d_loss, d_diff, d_norm = D_train.train_on_batch(
                [mb_fakes, mb_reals, interpolation],
                [np.ones((minibatch_size, 1, 1, 1)), np.ones((minibatch_size, 1))])
            d_score_real = D.predict_on_batch(mb_reals)
            d_score_fake = D.predict_on_batch(mb_fakes)
            print("real score: %f fake score: %f" % (np.mean(d_score_real), np.mean(d_score_fake)))
            cur_nimg += minibatch_size

        # Train G.
        mb_latents = random_latents(minibatch_size, G.input_shape)
        mb_labels_rnd = random_labels(minibatch_size, training_set)
        g_loss = G_train.train_on_batch([mb_latents], (-1) * np.ones((mb_latents.shape[0], 1, 1, 1)))
        print("%d [D loss: %f] [G loss: %f]" % (cur_nimg, d_loss, g_loss))

        fake_score_cur = np.clip(np.mean(d_loss), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0) ** gdrop_exp)
        if hasattr(D, 'gdrop_strength'): K.set_value(D.gdrop_strength, np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            tick_train_avg = tuple(np.mean(np.concatenate([np.asarray(v).flatten() for v in vals]))
                                   for vals in zip(*tick_train_out))
            tick_train_out = []

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(result_subdir, 'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_viz, grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                save_GD_weights(G, D, os.path.join(result_subdir, 'network-snapshot-%06d' % (cur_nimg // 1000)))

    save_GD(G, D, os.path.join(result_subdir, 'network-final'))
    training_set.close()
    print('Done.')


if __name__ == '__main__':
    np.random.seed(config.random_seed)
    func_params = dict(config.train)  # copy so config.train itself is not mutated
    func_name = func_params.pop('func')
    globals()[func_name](**func_params)
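# A minimal sketch of the config.train entry this entry point expects
# (hypothetical values; the real settings live in config.py). 'func' names the
# function to run and every other key is forwarded to it as a keyword argument:
#
#     train = dict(func='train_gan',
#                  D_training_repeats=1,
#                  G_learning_rate_max=0.0010,
#                  D_learning_rate_max=0.0010,
#                  resume_network=None)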